Skip to content

Commit abe0d1d

Browse files
committed
Begin setting up chat history database
1 parent d8123c7 commit abe0d1d

File tree

9 files changed

+179
-12
lines changed

9 files changed

+179
-12
lines changed

llamafile/db.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2+
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
3+
//
4+
// Copyright 2024 Mozilla Foundation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#include "db.h"
19+
#include <stdio.h>
20+
#include <string>
21+
22+
__static_yoink("llamafile/schema.sql");
23+
24+
#define SCHEMA_VERSION 1
25+
26+
namespace llamafile {
27+
namespace db {
28+
29+
static bool table_exists(sqlite3* db, const char* table_name) {
30+
const char* query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;";
31+
sqlite3_stmt* stmt;
32+
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
33+
return false;
34+
}
35+
if (sqlite3_bind_text(stmt, 1, table_name, -1, SQLITE_STATIC) != SQLITE_OK) {
36+
sqlite3_finalize(stmt);
37+
return false;
38+
}
39+
bool exists = sqlite3_step(stmt) == SQLITE_ROW;
40+
sqlite3_finalize(stmt);
41+
return exists;
42+
}
43+
44+
static bool init_schema(sqlite3* db) {
45+
FILE* f = fopen("/zip/llamafile/schema.sql", "r");
46+
if (!f)
47+
return false;
48+
std::string schema;
49+
int c;
50+
while ((c = fgetc(f)) != EOF)
51+
schema += c;
52+
fclose(f);
53+
char* errmsg = nullptr;
54+
int rc = sqlite3_exec(db, schema.c_str(), nullptr, nullptr, &errmsg);
55+
if (rc != SQLITE_OK) {
56+
if (errmsg) {
57+
fprintf(stderr, "SQL error: %s\n", errmsg);
58+
sqlite3_free(errmsg);
59+
}
60+
return false;
61+
}
62+
return true;
63+
}
64+
65+
sqlite3* open(const char* path) {
66+
sqlite3* db;
67+
int rc = sqlite3_open(path, &db);
68+
if (rc) {
69+
fprintf(stderr, "%s: can't open database: %s\n", path, sqlite3_errmsg(db));
70+
return nullptr;
71+
}
72+
char* errmsg = nullptr;
73+
if (sqlite3_exec(db, "PRAGMA journal_mode=WAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
74+
fprintf(stderr, "Failed to set journal mode to WAL: %s\n", errmsg);
75+
sqlite3_free(errmsg);
76+
sqlite3_close(db);
77+
return nullptr;
78+
}
79+
if (sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
80+
fprintf(stderr, "Failed to set synchronous to NORMAL: %s\n", errmsg);
81+
sqlite3_free(errmsg);
82+
sqlite3_close(db);
83+
return nullptr;
84+
}
85+
if (!table_exists(db, "metadata") && !init_schema(db)) {
86+
fprintf(stderr, "%s: failed to initialize database schema\n", path);
87+
sqlite3_close(db);
88+
return nullptr;
89+
}
90+
return db;
91+
}
92+
93+
void close(sqlite3* db) {
94+
sqlite3_close(db);
95+
}
96+
97+
} // namespace db
98+
} // namespace llamafile

llamafile/db.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2+
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
3+
//
4+
// Copyright 2024 Mozilla Foundation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#pragma once
19+
#include "third_party/sqlite/sqlite3.h"
20+
21+
namespace llamafile {
22+
namespace db {
23+
24+
sqlite3* open(const char*);
25+
void close(sqlite3*);
26+
27+
} // namespace db
28+
} // namespace llamafile

llamafile/flags.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ bool FLAG_tinyblas = false;
5353
bool FLAG_trace = false;
5454
bool FLAG_unsecure = false;
5555
const char *FLAG_chat_template = "";
56+
const char *FLAG_db = nullptr;
5657
const char *FLAG_file = nullptr;
5758
const char *FLAG_ip_header = nullptr;
5859
const char *FLAG_listen = "127.0.0.1:8080";
@@ -185,6 +186,13 @@ void llamafile_get_flags(int argc, char **argv) {
185186
continue;
186187
}
187188

189+
if (!strcmp(flag, "--db")) {
190+
if (i == argc)
191+
missing("--db");
192+
FLAG_db = argv[i++];
193+
continue;
194+
}
195+
188196
//////////////////////////////////////////////////////////////////////
189197
// server flags
190198

llamafile/llamafile.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extern bool FLAG_trace;
2424
extern bool FLAG_trap;
2525
extern bool FLAG_unsecure;
2626
extern const char *FLAG_chat_template;
27+
extern const char *FLAG_db;
2728
extern const char *FLAG_file;
2829
extern const char *FLAG_ip_header;
2930
extern const char *FLAG_listen;

llamafile/schema.sql

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
CREATE TABLE metadata (
2+
key TEXT PRIMARY KEY,
3+
value TEXT
4+
);
5+
6+
CREATE TABLE chats (
7+
id INTEGER PRIMARY KEY AUTOINCREMENT,
8+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
9+
model TEXT,
10+
title TEXT
11+
);
12+
13+
CREATE TABLE messages (
14+
id INTEGER PRIMARY KEY AUTOINCREMENT,
15+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
16+
chat_id INTEGER,
17+
role TEXT,
18+
message TEXT,
19+
temperature REAL,
20+
top_p REAL,
21+
presence_penalty REAL,
22+
frequency_penalty REAL,
23+
FOREIGN KEY (chat_id) REFERENCES chats(id)
24+
);

third_party/sqlite/BUILD.mk

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33

44
PKGS += THIRD_PARTY_SQLITE
55

6+
THIRD_PARTY_SQLITE_SRCS = \
7+
third_party/sqlite/sqlite3.c \
8+
third_party/sqlite/shell.c \
9+
10+
THIRD_PARTY_SQLITE_HDRS = \
11+
third_party/sqlite/sqlite3.h \
12+
613
o/$(MODE)/third_party/sqlite/sqlite.a: \
714
o/$(MODE)/third_party/sqlite/sqlite3.o \
815

third_party/sqlite/README.llamafile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ LICENSE
1313
LOCAL CHANGES
1414

1515
- Renamed <zlib.h> to <third_party/zlib/zlib.h>
16+
- Mangled some quoted includes to not confuse mkdeps

third_party/sqlite/shell.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ typedef sqlite3_int64 i64;
123123
typedef sqlite3_uint64 u64;
124124
typedef unsigned char u8;
125125
#if SQLITE_USER_AUTHENTICATION
126-
# include "sqlite3userauth.h"
126+
# includez "sqlite3userauth.h"
127127
#endif
128128
#include <ctype.h>
129129
#include <stdarg.h>
@@ -169,7 +169,7 @@ typedef unsigned char u8;
169169

170170
#elif HAVE_LINENOISE
171171

172-
# include "linenoise.h"
172+
# includez "linenoise.h"
173173
# define shell_add_history(X) linenoiseHistoryAdd(X)
174174
# define shell_read_history(X) linenoiseHistoryLoad(X)
175175
# define shell_write_history(X) linenoiseHistorySave(X)
@@ -1710,7 +1710,7 @@ static void shellAddSchemaName(
17101710
#define WIN32_LEAN_AND_MEAN
17111711
#endif
17121712

1713-
#include "windows.h"
1713+
#includez "windows.h"
17141714

17151715
/*
17161716
** We need several support functions from the SQLite core.
@@ -7996,10 +7996,10 @@ SQLITE_EXTENSION_INIT1
79967996
# include <utime.h>
79977997
# include <sys/time.h>
79987998
#else
7999-
# include "windows.h"
7999+
# includez "windows.h"
80008000
# include <io.h>
80018001
# include <direct.h>
8002-
/* # include "test_windirent.h" */
8002+
/* # includez "test_windirent.h" */
80038003
# define dirent DIRENT
80048004
# ifndef chmod
80058005
# define chmod _chmod
@@ -8945,7 +8945,7 @@ int sqlite3_fileio_init(
89458945
* redefined SQLite API calls as the above extension code does.
89468946
* Just pull in this .c to accomplish this. As a beneficial side
89478947
* effect, this extension becomes a single translation unit. */
8948-
# include "test_windirent.c"
8948+
# includez "test_windirent.c"
89498949
#endif
89508950

89518951
/************************* End ../ext/misc/fileio.c ********************/

third_party/sqlite/sqlite3.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,9 @@
280280
** disabled.
281281
*/
282282
#if defined(_HAVE_MINGW_H)
283-
# include "mingw.h"
283+
# includez "mingw.h"
284284
#elif defined(_HAVE__MINGW_H)
285-
# include "_mingw.h"
285+
# includez "_mingw.h"
286286
#endif
287287

288288
/*
@@ -13911,7 +13911,7 @@ struct fts5_api {
1391113911
** autoconf-based build
1391213912
*/
1391313913
#if defined(_HAVE_SQLITE_CONFIG_H) && !defined(SQLITECONFIG_H)
13914-
#include "sqlite_cfg.h"
13914+
#includez "sqlite_cfg.h"
1391513915
#define SQLITECONFIG_H 1
1391613916
#endif
1391713917

@@ -29996,7 +29996,7 @@ SQLITE_PRIVATE sqlite3_mutex_methods const *sqlite3DefaultMutex(void){
2999629996
/*
2999729997
** Include the primary Windows SDK header file.
2999829998
*/
29999-
#include "windows.h"
29999+
#includez "windows.h"
3000030000

3000130001
#ifdef __CYGWIN__
3000230002
# include <sys/cygwin.h>
@@ -196803,7 +196803,7 @@ SQLITE_PRIVATE int sqlite3Fts3InitTokenizer(
196803196803

196804196804
#ifdef SQLITE_TEST
196805196805

196806-
#include "tclsqlite.h"
196806+
#includez "tclsqlite.h"
196807196807
/* #include <string.h> */
196808196808

196809196809
/*
@@ -211715,7 +211715,7 @@ SQLITE_PRIVATE int sqlite3GetToken(const unsigned char*,int*); /* In the SQLite
211715211715
** found in sqliteInt.h
211716211716
*/
211717211717
#if !defined(SQLITE_AMALGAMATION)
211718-
#include "sqlite3rtree.h"
211718+
#includez "sqlite3rtree.h"
211719211719
typedef sqlite3_int64 i64;
211720211720
typedef sqlite3_uint64 u64;
211721211721
typedef unsigned char u8;

0 commit comments

Comments
 (0)